{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Using and extending the course PyTorch models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Christopher Potts\"\n",
    "__version__ = \"CS224u, Stanford, Spring 2023\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Overview](#Overview)\n",
    "1. [Set-up](#Set-up)\n",
    "1. [General optimization choices](#General-optimization-choices)\n",
    "1. [Classifiers](#Classifiers)\n",
    "    1. [Softmax classifier](#Softmax-classifier)\n",
    "    1. [A deeper neural classifier](#A-deeper-neural-classifier)\n",
    "1. [Regression](#Regression)\n",
    "    1. [Linear regression](#Linear-regression)\n",
    "    1. [Deeper Linear Regression](#Deeper-Linear-Regression)\n",
    "1. [RNN sequence labeling](#RNN-sequence-labeling)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "This repository contains a number of PyTorch modules designed to support our core content and provide tools for homeworks and bake-offs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch_autoencoder.py                 torch_model_base.py\n",
      "torch_deep_neural_classifier.py      torch_rnn_classifier.py\n",
      "torch_deep_neural_classifier_iit.py  torch_shallow_neural_classifier.py\n",
      "torch_glove.py\n"
     ]
    }
   ],
   "source": [
    "%ls torch*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The goal of the current notebook is to provide some guidance on how you can extend these modules to create original custom systems. Once you get used to how the code is structured, this is sure to be much faster than coding from scratch, and it still allows you a lot of freedom to design new models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The base class for all the modules is `torch_model_base.TorchModelBase`. The central role of this class is to provide a very full-featured `fit` method. See [General optimization choices](#General-optimization-choices) for an overview of the knobs and levers it provides. The interface is generic enough to accommodate a wide range of tasks.\n",
    "\n",
    "In what follows, we consider three kinds of extension, aiming to highlight general techniques and code patterns:\n",
    "\n",
    "* __Classifiers__: subclasses using `torch_shallow_neural_classifier.py`\n",
    "* __Regressors__: subclasses using `torch_model_base.py`\n",
    "* __RNN-based models__: subclasses using `torch_rnn_classifier.py`\n",
    "\n",
    "If you are experienced with PyTorch already, you can probably dive right into this notebook. If not, then I recommend [our PyTorch tutorial notebook](tutorial_pytorch.ipynb) to start."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "from sklearn.datasets import load_iris, fetch_california_housing\n",
    "from sklearn.metrics import classification_report, r2_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.model_selection import cross_validate\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from torch_model_base import TorchModelBase\n",
    "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
    "from torch_rnn_classifier import TorchRNNDataset, TorchRNNClassifier, TorchRNNModel\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.fix_random_seeds()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General optimization choices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `TorchModelBase` has a number of keyword parameters that relate to how models are optimized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['batch_size',\n",
       " 'max_iter',\n",
       " 'eta',\n",
       " 'optimizer_class',\n",
       " 'l2_strength',\n",
       " 'gradient_accumulation_steps',\n",
       " 'max_grad_norm',\n",
       " 'validation_fraction',\n",
       " 'early_stopping',\n",
       " 'n_iter_no_change',\n",
       " 'warm_start',\n",
       " 'tol']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TorchModelBase().params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For descriptions of what these parameters do, please refer to the docstring for the class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All of these parameters can be included in hyperparameter optimization runs using tools in `sklearn.model_selection`, as we'll see below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifiers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create new classifiers, one typically just needs to subclass `TorchShallowNeuralClassifier` and write a new `build_graph` method to define your computation graph. Here we illustrate with some representative examples, using the [Iris plants dataset](https://scikit-learn.org/stable/datasets/index.html#iris-dataset) for evaluations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iris_split():\n",
    "    dataset = load_iris()\n",
    "    X = dataset.data\n",
    "    y = dataset.target\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        X, y, test_size=0.33, random_state=42)\n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_cls_train, X_cls_test, y_cls_train, y_cls_test = iris_split()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a softmax classifier, we just need to write a simple `build_graph` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchSoftmaxClassifier(TorchShallowNeuralClassifier):\n",
    "\n",
    "    def build_graph(self):\n",
    "        return nn.Sequential(\n",
    "            nn.Linear(self.input_dim, self.n_classes_))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the data format and optimization process are the same as for `TorchShallowNeuralClassifier`, we needn't do anything beyond this."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TorchSoftmaxClassifier(\n",
       "\tbatch_size=1028,\n",
       "\tmax_iter=1000,\n",
       "\teta=0.001,\n",
       "\toptimizer_class=<class 'torch.optim.adam.Adam'>,\n",
       "\tl2_strength=0,\n",
       "\tgradient_accumulation_steps=1,\n",
       "\tmax_grad_norm=None,\n",
       "\tvalidation_fraction=0.1,\n",
       "\tearly_stopping=False,\n",
       "\tn_iter_no_change=10,\n",
       "\twarm_start=False,\n",
       "\ttol=1e-05,\n",
       "\thidden_dim=50,\n",
       "\thidden_activation=Tanh())"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sm_mod = TorchSoftmaxClassifier()\n",
    "\n",
    "sm_mod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: as you can see here, this model will still accept keyword arguments `hidden_dim` and `hidden_activation`, which will be ignored since the graph doesn't use them. I'll leave this minor inconsistency aside."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1000 of 1000; error is 0.4604862630367279"
     ]
    }
   ],
   "source": [
    "_ = sm_mod.fit(X_cls_train, y_cls_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "sm_preds = sm_mod.predict(X_cls_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        19\n",
      "           1       1.00      0.93      0.97        15\n",
      "           2       0.94      1.00      0.97        16\n",
      "\n",
      "    accuracy                           0.98        50\n",
      "   macro avg       0.98      0.98      0.98        50\n",
      "weighted avg       0.98      0.98      0.98        50\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_cls_test, sm_preds))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TorchModelBase` is able to [\"duck type\"](https://en.wikipedia.org/wiki/Duck_typing) standard `sklearn` estimators, so we can use the functionality from `sklearn.model_selection`. For example, here we use `sklearn.model_selection.cross_validate`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1000 of 1000; error is 0.57950609922409063"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'fit_time': array([1.48663092, 1.4696188 , 1.48087692, 1.51742196, 1.47723603]),\n",
       " 'score_time': array([0.00130987, 0.00146699, 0.00115705, 0.00122023, 0.00124574]),\n",
       " 'test_score': array([0.87777778, 0.63888889, 0.75925926, 0.51515152, 0.95681511])}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cross_validate(sm_mod, X_cls_train, y_cls_train, cv=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A deeper neural classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TorchShallowNeuralClassifier` is \"shallow\" in that it has just one hidden layer of representation. Adding a second is very straightforward. Again, all we really have to do is write a new `build_graph`, but the implementation below also includes a new `__init__` method to allow the user to separately control the sizes of the two hidden layers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchDeeperNeuralClassifier(TorchShallowNeuralClassifier):\n",
    "    def __init__(self, hidden_dim1=50, hidden_dim2=50, **base_kwargs):\n",
    "        super().__init__(**base_kwargs)\n",
    "        self.hidden_dim1 = hidden_dim1\n",
    "        self.hidden_dim2 = hidden_dim2\n",
    "        # Good to remove this to avoid confusion:\n",
    "        self.params.remove(\"hidden_dim\")\n",
    "        # Add the new parameters to support model_selection using them:\n",
    "        self.params += [\"hidden_dim1\", \"hidden_dim2\"]\n",
    "\n",
    "    def build_graph(self):\n",
    "        return nn.Sequential(\n",
    "            nn.Linear(self.input_dim, self.hidden_dim1),\n",
    "            self.hidden_activation,\n",
    "            nn.Linear(self.hidden_dim1, self.hidden_dim2),\n",
    "            self.hidden_activation,\n",
    "            nn.Linear(self.hidden_dim2, self.n_classes_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TorchDeeperNeuralClassifier(\n",
       "\tbatch_size=1028,\n",
       "\tmax_iter=1000,\n",
       "\teta=0.001,\n",
       "\toptimizer_class=<class 'torch.optim.adam.Adam'>,\n",
       "\tl2_strength=0,\n",
       "\tgradient_accumulation_steps=1,\n",
       "\tmax_grad_norm=None,\n",
       "\tvalidation_fraction=0.1,\n",
       "\tearly_stopping=False,\n",
       "\tn_iter_no_change=10,\n",
       "\twarm_start=False,\n",
       "\ttol=1e-05,\n",
       "\thidden_activation=Tanh(),\n",
       "\thidden_dim1=50,\n",
       "\thidden_dim2=50)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deep_mod = TorchDeeperNeuralClassifier()\n",
    "\n",
    "deep_mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1000 of 1000; error is 0.02787744253873825"
     ]
    }
   ],
   "source": [
    "_ = deep_mod.fit(X_cls_train, y_cls_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "deep_preds = deep_mod.predict(X_cls_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        19\n",
      "           1       0.94      1.00      0.97        15\n",
      "           2       1.00      0.94      0.97        16\n",
      "\n",
      "    accuracy                           0.98        50\n",
      "   macro avg       0.98      0.98      0.98        50\n",
      "weighted avg       0.98      0.98      0.98        50\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_cls_test, deep_preds))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To try to find optimal values for the hidden layer dimensionalities, we could do some hyperparameter tuning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1000 of 1000; error is 0.084474258124828345"
     ]
    }
   ],
   "source": [
    "xval = GridSearchCV(\n",
    "    TorchDeeperNeuralClassifier(),\n",
    "    param_grid={\n",
    "        'hidden_dim1': [5, 10],\n",
    "        'hidden_dim2': [5, 10]})\n",
    "\n",
    "_ = xval.fit(X_cls_train, y_cls_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9672889488678962"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xval.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TorchDeeperNeuralClassifier(\n",
       "\tbatch_size=1028,\n",
       "\tmax_iter=1000,\n",
       "\teta=0.001,\n",
       "\toptimizer_class=<class 'torch.optim.adam.Adam'>,\n",
       "\tl2_strength=0,\n",
       "\tgradient_accumulation_steps=1,\n",
       "\tmax_grad_norm=None,\n",
       "\tvalidation_fraction=0.1,\n",
       "\tearly_stopping=False,\n",
       "\tn_iter_no_change=10,\n",
       "\twarm_start=False,\n",
       "\ttol=1e-05,\n",
       "\thidden_activation=Tanh(),\n",
       "\thidden_dim1=5,\n",
       "\thidden_dim2=5)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xval.best_estimator_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is also easy to write regression models. For these, we will `TorchModelBase`, since some fundamental things are different from the classifiers above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For illustrations, we'll use a random split of the [Boston house prices](https://scikit-learn.org/stable/datasets/index.html#boston-dataset) dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def boston_split():\n",
    "    dataset = fetch_california_housing()\n",
    "    X = dataset.data\n",
    "    y = dataset.target\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        X, y, test_size=0.33, random_state=42)\n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_reg_train, X_reg_test, y_reg_train, y_reg_test = boston_split()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For linear regression, we create an `nn.Module` subclass:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchLinearRegressionModel(nn.Module):\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.w = nn.Parameter(torch.zeros(self.input_dim))\n",
    "        self.b = nn.Parameter(torch.zeros(1))\n",
    "\n",
    "    def forward(self, X):\n",
    "        return X.matmul(self.w) + self.b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The estimator itself, a subclass of `TorchModelBase`, needs the following methods:\n",
    "\n",
    "* `build_graph`: to use `TorchLinearRegressionModel` from above.\n",
    "* `build_dataset`: for processing the data.\n",
    "* `predict`: for making predictions.\n",
    "* `score`: technically optional, but required for `sklearn.model_selection` usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchLinearRegresson(TorchModelBase):\n",
    "    def __init__(self, **base_kwargs):\n",
    "        super().__init__(**base_kwargs)\n",
    "        self.loss = nn.MSELoss(reduction=\"mean\")\n",
    "\n",
    "    def build_graph(self):\n",
    "        return TorchLinearRegressionModel(self.input_dim)\n",
    "\n",
    "    def build_dataset(self, X, y=None):\n",
    "        \"\"\"\n",
    "        This function will be used in training (when there is a `y`)\n",
    "        and in prediction (no `y`). For both cases, we rely on a\n",
    "        `TensorDataset`.\n",
    "        \"\"\"\n",
    "        X = torch.FloatTensor(X)\n",
    "        self.input_dim = X.shape[1]\n",
    "        if y is None:\n",
    "            dataset = torch.utils.data.TensorDataset(X)\n",
    "        else:\n",
    "            y = torch.FloatTensor(y)\n",
    "            dataset = torch.utils.data.TensorDataset(X, y)\n",
    "        return dataset\n",
    "\n",
    "    def predict(self, X, device=None):\n",
    "        \"\"\"\n",
    "        The `_predict` function of the base class handles all the\n",
    "        details around data formatting. In this case, the\n",
    "        raw output of `self.model`, as given by\n",
    "        `TorchLinearRegressionModel.forward` is all we need.\n",
    "        \"\"\"\n",
    "        return self._predict(X, device=device).cpu().numpy()\n",
    "\n",
    "    def score(self, X, y):\n",
    "        \"\"\"\n",
    "        Follow sklearn in using `r2_score` as the default scorer.\n",
    "        \"\"\"\n",
    "        preds = self.predict(X)\n",
    "        return r2_score(y, preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TorchLinearRegresson(\n",
       "\tbatch_size=1028,\n",
       "\tmax_iter=1000,\n",
       "\teta=0.001,\n",
       "\toptimizer_class=<class 'torch.optim.adam.Adam'>,\n",
       "\tl2_strength=0,\n",
       "\tgradient_accumulation_steps=1,\n",
       "\tmax_grad_norm=None,\n",
       "\tvalidation_fraction=0.1,\n",
       "\tearly_stopping=False,\n",
       "\tn_iter_no_change=10,\n",
       "\twarm_start=False,\n",
       "\ttol=1e-05)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr = TorchLinearRegresson()\n",
    "\n",
    "lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 133. Training loss did not improve more than tol=1e-05. Final error is 9.422466397285461."
     ]
    }
   ],
   "source": [
    "_ = lr.fit(X_reg_train, y_reg_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_preds = lr.predict(X_reg_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.49816556936022427"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r2_score(y_reg_test, lr_preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deeper Linear Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can extend the subclass we just created to easily create deeper regression models. Here's an example showing that all we need is the deeper `nn.Module` and a new `build_graph` method in the main estimator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchLinearRegressionModel(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, hidden_activation):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.hidden_activation = hidden_activation\n",
    "        self.input_layer = nn.Linear(self.input_dim, self.hidden_dim)\n",
    "        self.w = nn.Parameter(torch.zeros(self.hidden_dim))\n",
    "        self.b = nn.Parameter(torch.zeros(1))\n",
    "\n",
    "    def forward(self, X):\n",
    "        h = self.hidden_activation(self.input_layer(X))\n",
    "        return h.matmul(self.w) + self.b\n",
    "\n",
    "\n",
    "class TorchDeeperLinearRegression(TorchLinearRegresson):\n",
    "    def __init__(self, hidden_dim=20, hidden_activation=nn.Tanh(), **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.hidden_activation = hidden_activation\n",
    "        self.params += [\"hidden_dim\", \"hidden_activation\"]\n",
    "\n",
    "    def build_graph(self):\n",
    "        return TorchLinearRegressionModel(\n",
    "            input_dim=self.input_dim,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            hidden_activation=self.hidden_activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TorchDeeperLinearRegression(\n",
       "\tbatch_size=1028,\n",
       "\tmax_iter=1000,\n",
       "\teta=0.001,\n",
       "\toptimizer_class=<class 'torch.optim.adam.Adam'>,\n",
       "\tl2_strength=0,\n",
       "\tgradient_accumulation_steps=1,\n",
       "\tmax_grad_norm=None,\n",
       "\tvalidation_fraction=0.1,\n",
       "\tearly_stopping=False,\n",
       "\tn_iter_no_change=10,\n",
       "\twarm_start=False,\n",
       "\ttol=1e-05,\n",
       "\thidden_dim=20,\n",
       "\thidden_activation=Tanh())"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deep_lr = TorchDeeperLinearRegression()\n",
    "\n",
    "deep_lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 31. Training loss did not improve more than tol=1e-05. Final error is 18.68722093105316."
     ]
    }
   ],
   "source": [
    "_ = deep_lr.fit(X_reg_train, y_reg_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "deep_lr_preds = deep_lr.predict(X_reg_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.0004565544982604308"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r2_score(y_reg_test, deep_lr_preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN sequence labeling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a final illustrative example, let's make use of our existing RNN classifier components to create a model that can do full sequence labeling. PyTorch's abstractions concerning how layers interact and how loss functions work make this surprisingly easy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For examples, we'll use the CoNLL 2002 shared task on named entity labeling in Spanish. NLTK provides an easy interface:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package conll2002 to\n",
      "[nltk_data]     /Users/cgpotts/Documents/data/nltk_data/...\n",
      "[nltk_data]   Package conll2002 is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ensure that the CoNLL dataset is present. If this function call\n",
    "# raises errors, you might need to go through the steps of\n",
    "# installing the NLTK data: https://www.nltk.org/data.html\n",
    "\n",
    "nltk.download(\"conll2002\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sequence_dataset():\n",
    "    train_seq = nltk.corpus.conll2002.iob_sents('esp.train')\n",
    "    X = [[x[0] for x in seq] for seq in train_seq]\n",
    "    y = [[x[2] for x in seq] for seq in train_seq]\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        X, y, test_size=0.33, random_state=42)\n",
    "    vocab = sorted({w for seq in X_train for w in seq}) + [\"$UNK\"]\n",
    "    return X_train, X_test, y_train, y_test, vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    " X_seq_train, X_seq_test, y_seq_train, y_seq_test, seq_vocab = sequence_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's are the first few tokens in the first training example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['La', 'compañía', 'estatal', 'de', 'electricidad', 'de', 'Suecia', ',']"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_seq_train[0][: 8]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And the corresponding labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_seq_train[0][: 8]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start with the `nn.Module` subclass we need. In `torch_rnn_classifier.py`, we already have a pretty generic RNN module: `TorchRNNModel`. For the classifier use, `TorchRNNClassifierModel` uses the output of `TorchRNNModel` to define a classifier based on the final output state. For sequence labeling, we drop `TorchRNNClassifierModel` and replace it with model that has a classifier on every output state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchSequenceLabeler(nn.Module):\n",
    "    def __init__(self, rnn, output_dim):\n",
    "        super().__init__()\n",
    "        self.rnn = rnn\n",
    "        self.output_dim = output_dim\n",
    "        if self.rnn.bidirectional:\n",
    "            self.classifier_dim = self.rnn.hidden_dim * 2\n",
    "        else:\n",
    "            self.classifier_dim = self.rnn.hidden_dim\n",
    "        self.classifier_layer = nn.Linear(\n",
    "            self.classifier_dim, self.output_dim)\n",
    "\n",
    "    def forward(self, X, seq_lengths):\n",
    "        outputs, state = self.rnn(X, seq_lengths)\n",
    "        outputs, seq_length = torch.nn.utils.rnn.pad_packed_sequence(\n",
    "            outputs, batch_first=True)\n",
    "        logits = self.classifier_layer(outputs)\n",
    "        # During training, we need to swap the dimensions of logits\n",
    "        # to accommodate `nn.CrossEntropyLoss`:\n",
    "        if self.training:\n",
    "            return logits.transpose(1, 2)\n",
    "        else:\n",
    "            return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We won't normally interact with this module directly, but it's perhaps instructive to see how it works on its own:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = 4\n",
    "\n",
    "seq_rnn = TorchRNNModel(vocab_size, embed_dim=4, hidden_dim=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_module = TorchSequenceLabeler(seq_rnn, vocab_size)\n",
    "\n",
    "_ = seq_module.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_seqs = torch.LongTensor([[0,1,2], [0,2,1]])\n",
    "\n",
    "seq_lengths = torch.LongTensor([3,3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This should return two sequences of 4-dimensional vectors – the per-token logits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.3431,  0.2043, -0.1963,  0.1925],\n",
       "         [-0.4260,  0.1578, -0.2943,  0.2358],\n",
       "         [-0.4403,  0.1517, -0.2393,  0.2148]],\n",
       "\n",
       "        [[-0.3431,  0.2043, -0.1963,  0.1925],\n",
       "         [-0.4059,  0.1804, -0.2166,  0.2285],\n",
       "         [-0.4345,  0.1526, -0.2976,  0.2378]]], grad_fn=<ViewBackward0>)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_module(toy_seqs, seq_lengths)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remaining tasks concern the new estimator. We need to define the following methods:\n",
    "\n",
    "* `build_graph`: builds the model using `TorchSequenceLabeler`.\n",
    "* `build_dataset`: like the classifier version, but the labels for each example are a full sequence.\n",
    "* `predict_proba`: like a classifier `predict_proba`, but it removes any sequence padding and returns one distribution per token.\n",
    "* `predict`: like a classifier `predict`, but it returns a sequence of labels for each example.\n",
    "* `score`: like a classifier `score`, but it flattens the sequences before computing macro-F1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchRNNSequenceLabeler(TorchRNNClassifier):\n",
    "\n",
    "    def build_graph(self):\n",
    "        rnn = TorchRNNModel(\n",
    "            vocab_size=len(self.vocab),\n",
    "            embedding=self.embedding,\n",
    "            use_embedding=self.use_embedding,\n",
    "            embed_dim=self.embed_dim,\n",
    "            rnn_cell_class=self.rnn_cell_class,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            bidirectional=self.bidirectional,\n",
    "            freeze_embedding=self.freeze_embedding)\n",
    "        model = TorchSequenceLabeler(\n",
    "            rnn=rnn,\n",
    "            output_dim=self.n_classes_)\n",
    "        self.embed_dim = rnn.embed_dim\n",
    "        return model\n",
    "\n",
    "    def build_dataset(self, X, y=None):\n",
    "        X, seq_lengths = self._prepare_sequences(X)\n",
    "        if y is None:\n",
    "            return TorchRNNDataset(X, seq_lengths)\n",
    "        else:\n",
    "            # These are the changes from a regular classifier. All\n",
    "            # concern the fact that our labels are sequences of labels.\n",
    "            self.classes_ = sorted({x for seq in y for x in seq})\n",
    "            self.n_classes_ = len(self.classes_)\n",
    "            class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
    "            # `y` is a list of tensors of different lengths. Our Dataset\n",
    "            # class will turn it into a padded tensor for processing.\n",
    "            y = [torch.tensor([class2index[label] for label in seq])\n",
    "                 for seq in y]\n",
    "            return TorchRNNDataset(X, seq_lengths, y)\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        seq_lengths = [len(ex) for ex in X]\n",
    "        # The base class does the heavy lifting:\n",
    "        preds = self._predict(X)\n",
    "        # Trim to the actual sequence lengths:\n",
    "        preds = [p[:n] for p, n in zip(preds, seq_lengths)]\n",
    "        # Use `softmax`; the model doesn't do this because the loss\n",
    "        # function does it internally.\n",
    "        probs = [torch.softmax(seq, dim=1) for seq in preds]\n",
    "        return probs\n",
    "\n",
    "    def predict(self, X):\n",
    "        probs = self.predict_proba(X)\n",
    "        return [[self.classes_[i] for i in seq.argmax(axis=1)] for seq in probs]\n",
    "\n",
    "    def score(self, X, y):\n",
    "        preds = self.predict(X)\n",
    "        flat_preds = [x for seq in preds for x in seq]\n",
    "        flat_y = [x for seq in y for x in seq]\n",
    "        return utils.safe_macro_f1(flat_y, flat_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_mod = TorchRNNSequenceLabeler(\n",
    "    seq_vocab,\n",
    "    early_stopping=True,\n",
    "    eta=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 19. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 8.503544092178345"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 11min 33s, sys: 2min 29s, total: 14min 3s\n",
      "Wall time: 2min 16s\n"
     ]
    }
   ],
   "source": [
    "%time _ = seq_mod.fit(X_seq_train, y_seq_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.11581991000549478"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_mod.score(X_seq_test, y_seq_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
